diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py
index bca9fc4ab373..1d5bd57cf8e0 100644
--- a/src/diffusers/models/unets/unet_3d_condition.py
+++ b/src/diffusers/models/unets/unet_3d_condition.py
@@ -54,7 +54,7 @@ class UNet3DConditionOutput(BaseOutput):
     The output of [`UNet3DConditionModel`].
 
     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
             The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
     """
 
@@ -74,9 +74,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             Height and width of input/output sample.
         in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
         out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
-        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
             The tuple of downsample blocks to use.
-        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
             The tuple of upsample blocks to use.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
             The tuple of output channels for each block.
@@ -87,8 +87,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
             If `None`, normalization and activation layers is skipped in post-processing.
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
-        cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
-        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+        cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
+        attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
         num_attention_heads (`int`, *optional*): The number of attention heads.
     """
 
@@ -533,7 +533,7 @@ def forward(
 
         Args:
            sample (`torch.FloatTensor`):
-                The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
+                The noisy input tensor with the following shape `(batch, num_channels, num_frames, height, width)`.
             timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
             encoder_hidden_states (`torch.FloatTensor`):
                 The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
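
As a quick sanity check of the corrected `(batch_size, num_channels, num_frames, height, width)` layout, here is a minimal sketch that instantiates `UNet3DConditionModel` with a deliberately tiny, made-up config (the channel and attention sizes below are illustrative assumptions, not taken from any released checkpoint) and runs one forward pass. It targets the current `diffusers` API and may need small adjustments for other versions.

import torch
from diffusers import UNet3DConditionModel

# Tiny illustrative config so the model builds fast; real checkpoints use
# block_out_channels=(320, 640, 1280, 1280), cross_attention_dim=1024, attention_head_dim=64.
unet = UNet3DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
    up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
    block_out_channels=(32, 64),
    layers_per_block=1,
    norm_num_groups=32,
    cross_attention_dim=32,
    attention_head_dim=8,
)

batch, channels, frames, height, width = 1, 4, 2, 32, 32
# Channels come before frames, matching the corrected docstrings above.
sample = torch.randn(batch, channels, frames, height, width)
encoder_hidden_states = torch.randn(batch, 77, 32)  # (batch, sequence_length, cross_attention_dim)

with torch.no_grad():
    out = unet(sample, timestep=0, encoder_hidden_states=encoder_hidden_states).sample

print(out.shape)  # expected torch.Size([1, 4, 2, 32, 32]), i.e. (batch, num_channels, num_frames, height, width)

The output keeps the same channels-before-frames ordering as the input, which is what the docstring changes in this diff document.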