Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
376 changes: 64 additions & 312 deletions src/diffusers/pipelines/animatediff/pipeline_animatediff.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import AnimateDiffPipelineOutput

Expand Down Expand Up @@ -163,7 +164,9 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
class AnimateDiffVideoToVideoPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FreeInitMixin
):
r"""
Pipeline for video-to-video generation.

Expand Down Expand Up @@ -193,7 +196,7 @@ class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderM
"""

model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = ["feature_extractor", "image_encoder"]
_optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

def __init__(
Expand All @@ -215,7 +218,8 @@ def __init__(
image_encoder: CLIPVisionModelWithProjection = None,
):
super().__init__()
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
if isinstance(unet, UNet2DConditionModel):
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)

self.register_modules(
vae=vae,
Expand Down Expand Up @@ -584,12 +588,12 @@ def check_inputs(
if video is not None and latents is not None:
raise ValueError("Only one of `video` or `latents` should be provided")

def get_timesteps(self, num_inference_steps, strength, device):
def get_timesteps(self, num_inference_steps, timesteps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
timesteps = timesteps[t_start * self.scheduler.order :]

return timesteps, num_inference_steps - t_start

Expand Down Expand Up @@ -876,9 +880,8 @@ def __call__(

# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
self._num_timesteps = len(timesteps)

# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
Expand All @@ -901,42 +904,55 @@ def __call__(
# 7. Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
).sample

# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

progress_bar.update()
num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
for free_init_iter in range(num_free_init_iters):
if self.free_init_enabled:
latents, timesteps = self._apply_free_init(
latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
)
num_inference_steps = len(timesteps)
# make sure to readjust timesteps based on strength
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)

self._num_timesteps = len(timesteps)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
# 8. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
).sample

# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()

if output_type == "latent":
return AnimateDiffPipelineOutput(frames=latents)
Expand Down
184 changes: 184 additions & 0 deletions src/diffusers/pipelines/free_init_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Tuple, Union

import torch
import torch.fft as fft

from ..utils.torch_utils import randn_tensor


class FreeInitMixin:
r"""Mixin class for FreeInit."""

def enable_free_init(
self,
num_iters: int = 3,
use_fast_sampling: bool = False,
method: str = "butterworth",
order: int = 4,
spatial_stop_frequency: float = 0.25,
temporal_stop_frequency: float = 0.25,
):
"""Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537.
This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit).
Args:
num_iters (`int`, *optional*, defaults to `3`):
Number of FreeInit noise re-initialization iterations.
use_fast_sampling (`bool`, *optional*, defaults to `False`):
Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables
the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`.
method (`str`, *optional*, defaults to `butterworth`):
Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the
FreeInit low pass filter.
order (`int`, *optional*, defaults to `4`):
Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour
whereas lower values lead to `gaussian` method behaviour.
spatial_stop_frequency (`float`, *optional*, defaults to `0.25`):
Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in
the original implementation.
temporal_stop_frequency (`float`, *optional*, defaults to `0.25`):
Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in
the original implementation.
"""
self._free_init_num_iters = num_iters
self._free_init_use_fast_sampling = use_fast_sampling
self._free_init_method = method
self._free_init_order = order
self._free_init_spatial_stop_frequency = spatial_stop_frequency
self._free_init_temporal_stop_frequency = temporal_stop_frequency

def disable_free_init(self):
"""Disables the FreeInit mechanism if enabled."""
self._free_init_num_iters = None

@property
def free_init_enabled(self):
return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None

def _get_free_init_freq_filter(
self,
shape: Tuple[int, ...],
device: Union[str, torch.dtype],
filter_type: str,
order: float,
spatial_stop_frequency: float,
temporal_stop_frequency: float,
) -> torch.Tensor:
r"""Returns the FreeInit filter based on filter type and other input conditions."""

time, height, width = shape[-3], shape[-2], shape[-1]
mask = torch.zeros(shape)

if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
return mask

if filter_type == "butterworth":

def retrieve_mask(x):
return 1 / (1 + (x / spatial_stop_frequency**2) ** order)
elif filter_type == "gaussian":

def retrieve_mask(x):
return math.exp(-1 / (2 * spatial_stop_frequency**2) * x)
elif filter_type == "ideal":

def retrieve_mask(x):
return 1 if x <= spatial_stop_frequency * 2 else 0
else:
raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")

for t in range(time):
for h in range(height):
for w in range(width):
d_square = (
((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / time - 1)) ** 2
+ (2 * h / height - 1) ** 2
+ (2 * w / width - 1) ** 2
)
mask[..., t, h, w] = retrieve_mask(d_square)

return mask.to(device)

def _apply_freq_filter(self, x: torch.Tensor, noise: torch.Tensor, low_pass_filter: torch.Tensor) -> torch.Tensor:
r"""Noise reinitialization."""
# FFT
x_freq = fft.fftn(x, dim=(-3, -2, -1))
x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))

# frequency mix
high_pass_filter = 1 - low_pass_filter
x_freq_low = x_freq * low_pass_filter
noise_freq_high = noise_freq * high_pass_filter
x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain

# IFFT
x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real

return x_mixed

def _apply_free_init(
self,
latents: torch.Tensor,
free_init_iteration: int,
num_inference_steps: int,
device: torch.device,
dtype: torch.dtype,
generator: torch.Generator,
):
if free_init_iteration == 0:
self._free_init_initial_noise = latents.detach().clone()
return latents, self.scheduler.timesteps
Copy link
Contributor

@a-r-r-o-w a-r-r-o-w Feb 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DN6 This is incorrect and seems like a regression from the old implementation. I was trying to debug why AnimateLCM was failing to produce good results and stumbled upon this other issue (it does produce good results btw except for when use_fast_sampling==False. setting it to True seems to give good results).

Copy the FreeInit code from here and execute. You will see that the first iteration runs for 20 steps, second iteration runs for 13 steps and third iteration runs for 20 steps. This is incorrect because when use_fast_sampling=True, it should be 7, 13 and 20 but we return here without the fast sampling check.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DN6 Could I open a PR fixing this behavior since this has been merged already?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @a-r-r-o-w missed this. Yes please feel free to open a PR.


latent_shape = latents.shape

free_init_filter_shape = (1, *latent_shape[1:])
free_init_freq_filter = self._get_free_init_freq_filter(
shape=free_init_filter_shape,
device=device,
filter_type=self._free_init_method,
order=self._free_init_order,
spatial_stop_frequency=self._free_init_spatial_stop_frequency,
temporal_stop_frequency=self._free_init_temporal_stop_frequency,
)

current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()

z_t = self.scheduler.add_noise(
original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
).to(dtype=torch.float32)

z_rand = randn_tensor(
shape=latent_shape,
generator=generator,
device=device,
dtype=torch.float32,
)
latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
latents = latents.to(dtype)

# Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
if self._free_init_use_fast_sampling:
num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
self.scheduler.set_timesteps(num_inference_steps, device=device)

return latents, self.scheduler.timesteps
Loading