Commit d2fc5eb
[Refactor] FreeInit for AnimateDiff based pipelines (#6874)
* update * update * update * update * update * update * update * update * update * update
1 parent 779eef9 commit d2fc5eb

File tree

7 files changed: +398 −595 lines changed

src/diffusers/pipelines/animatediff/pipeline_animatediff.py

Lines changed: 64 additions & 312 deletions
Large diffs are not rendered by default.
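The pipeline above is where FreeInit was previously inlined; this refactor moves that logic into the `FreeInitMixin` added below, so usage stays a one-liner on the pipeline object. A minimal sketch of the resulting API, assuming standard AnimateDiff checkpoints (the model IDs are illustrative, not part of this commit):

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# enable_free_init/disable_free_init now come from FreeInitMixin rather than
# being reimplemented inside each pipeline
pipe.enable_free_init(num_iters=3, method="butterworth")
frames = pipe(prompt="a panda surfing", num_inference_steps=20).frames[0]
pipe.disable_free_init()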

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

Lines changed: 59 additions & 43 deletions
@@ -34,6 +34,7 @@
 )
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import randn_tensor
+from ..free_init_utils import FreeInitMixin
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import AnimateDiffPipelineOutput

@@ -163,7 +164,9 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
+class AnimateDiffVideoToVideoPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FreeInitMixin
+):
     r"""
     Pipeline for video-to-video generation.

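This hunk is essentially the whole integration surface per pipeline: FreeInit arrives through plain multiple inheritance. A toy illustration of the pattern (demo names only, not diffusers API):

class FreeInitDemoMixin:
    def enable_free_init(self, num_iters=3):
        self._free_init_num_iters = num_iters

    @property
    def free_init_enabled(self):
        return getattr(self, "_free_init_num_iters", None) is not None


class DemoPipeline(FreeInitDemoMixin):  # stands in for DiffusionPipeline and the loader mixins
    pass


pipe = DemoPipeline()
assert not pipe.free_init_enabled
pipe.enable_free_init()
assert pipe.free_init_enabled and pipe._free_init_num_iters == 3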
@@ -193,7 +196,7 @@ class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderM
     """

     model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
-    _optional_components = ["feature_extractor", "image_encoder"]
+    _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

     def __init__(
@@ -215,7 +218,8 @@ def __init__(
         image_encoder: CLIPVisionModelWithProjection = None,
     ):
         super().__init__()
-        unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+        if isinstance(unet, UNet2DConditionModel):
+            unet = UNetMotionModel.from_unet2d(unet, motion_adapter)

         self.register_modules(
             vae=vae,
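The new `isinstance` guard means `__init__` accepts either a `UNet2DConditionModel` to be merged with a `MotionAdapter`, or an already-merged `UNetMotionModel`, which is what lets `motion_adapter` become an optional component above. A rough sketch of both paths, with illustrative checkpoint IDs:

import torch
from diffusers import AnimateDiffVideoToVideoPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")

# Path 1 (unchanged): a 2D UNet is loaded and merged with the adapter in __init__
pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
)

# Path 2 (new): pipe.unet is already a UNetMotionModel, and the guard accepts it as-is
pipe2 = AnimateDiffVideoToVideoPipeline(
    vae=pipe.vae,
    text_encoder=pipe.text_encoder,
    tokenizer=pipe.tokenizer,
    unet=pipe.unet,  # already merged, no conversion happens
    motion_adapter=None,  # now genuinely optional
    scheduler=pipe.scheduler,
    feature_extractor=pipe.feature_extractor,
    image_encoder=pipe.image_encoder,
)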
@@ -584,12 +588,12 @@ def check_inputs(
         if video is not None and latents is not None:
             raise ValueError("Only one of `video` or `latents` should be provided")

-    def get_timesteps(self, num_inference_steps, strength, device):
+    def get_timesteps(self, num_inference_steps, timesteps, strength, device):
         # get the original timestep using init_timestep
         init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

         t_start = max(num_inference_steps - init_timestep, 0)
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        timesteps = timesteps[t_start * self.scheduler.order :]

         return timesteps, num_inference_steps - t_start

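The strength arithmetic is unchanged; the only difference is that the slice is taken from the `timesteps` argument (which FreeInit may have regenerated) rather than from `self.scheduler.timesteps`. Worked standalone, assuming `scheduler.order == 1`:

# standalone mirror of get_timesteps
num_inference_steps, strength = 20, 0.6

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 12
t_start = max(num_inference_steps - init_timestep, 0)                          # 8

timesteps = list(range(951, -1, -50))  # a 20-entry, noisiest-first schedule
timesteps = timesteps[t_start:]        # drop the 8 noisiest steps, keep the last 12
print(len(timesteps), num_inference_steps - t_start)  # 12 12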
@@ -876,9 +880,8 @@ def __call__(

         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
-        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
-        self._num_timesteps = len(timesteps)

         # 5. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
@@ -901,42 +904,55 @@ def __call__(
         # 7. Add image embeds for IP-Adapter
         added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

-        # 8. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                # predict the noise residual
-                noise_pred = self.unet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=self.cross_attention_kwargs,
-                    added_cond_kwargs=added_cond_kwargs,
-                ).sample
-
-                # perform guidance
-                if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
-                progress_bar.update()
+        num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
+        for free_init_iter in range(num_free_init_iters):
+            if self.free_init_enabled:
+                latents, timesteps = self._apply_free_init(
+                    latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
+                )
+                num_inference_steps = len(timesteps)
+                # make sure to readjust timesteps based on strength
+                timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
+
+            self._num_timesteps = len(timesteps)
+            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+            # 8. Denoising loop
+            with self.progress_bar(total=num_inference_steps) as progress_bar:
+                for i, t in enumerate(timesteps):
+                    # expand the latents if we are doing classifier free guidance
+                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # predict the noise residual
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        cross_attention_kwargs=self.cross_attention_kwargs,
+                        added_cond_kwargs=added_cond_kwargs,
+                    ).sample
+
+                    # perform guidance
+                    if self.do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                    if callback_on_step_end is not None:
+                        callback_kwargs = {}
+                        for k in callback_on_step_end_tensor_inputs:
+                            callback_kwargs[k] = locals()[k]
+                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                        latents = callback_outputs.pop("latents", latents)
+                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                    # call the callback, if provided
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()

         if output_type == "latent":
             return AnimateDiffPipelineOutput(frames=latents)
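With FreeInit enabled, the denoising loop now runs once per FreeInit pass, so any `callback_on_step_end` fires `num_iters * len(timesteps)` times in total (the step index `i` resets each pass). A small self-contained sketch, again with illustrative model IDs:

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")
pipe.enable_free_init(num_iters=3)


def log_step(pipeline, step_index, timestep, callback_kwargs):
    # with num_iters=3 this prints 3 * num_inference_steps times
    print(f"step {step_index}, t={int(timestep)}, latents {tuple(callback_kwargs['latents'].shape)}")
    return callback_kwargs


frames = pipe(
    prompt="a panda surfing",
    num_inference_steps=20,
    callback_on_step_end=log_step,
    callback_on_step_end_tensor_inputs=["latents"],
).frames[0]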
src/diffusers/pipelines/free_init_utils.py

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Tuple, Union
+
+import torch
+import torch.fft as fft
+
+from ..utils.torch_utils import randn_tensor
+
+
+class FreeInitMixin:
+    r"""Mixin class for FreeInit."""
+
+    def enable_free_init(
+        self,
+        num_iters: int = 3,
+        use_fast_sampling: bool = False,
+        method: str = "butterworth",
+        order: int = 4,
+        spatial_stop_frequency: float = 0.25,
+        temporal_stop_frequency: float = 0.25,
+    ):
+        """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537.
+
+        This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit).
+
+        Args:
+            num_iters (`int`, *optional*, defaults to `3`):
+                Number of FreeInit noise re-initialization iterations.
+            use_fast_sampling (`bool`, *optional*, defaults to `False`):
+                Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables
+                the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`.
+            method (`str`, *optional*, defaults to `butterworth`):
+                Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the FreeInit
+                low pass filter.
+            order (`int`, *optional*, defaults to `4`):
+                Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour
+                whereas lower values lead to `gaussian` method behaviour.
+            spatial_stop_frequency (`float`, *optional*, defaults to `0.25`):
+                Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in
+                the original implementation.
+            temporal_stop_frequency (`float`, *optional*, defaults to `0.25`):
+                Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in
+                the original implementation.
+        """
+        self._free_init_num_iters = num_iters
+        self._free_init_use_fast_sampling = use_fast_sampling
+        self._free_init_method = method
+        self._free_init_order = order
+        self._free_init_spatial_stop_frequency = spatial_stop_frequency
+        self._free_init_temporal_stop_frequency = temporal_stop_frequency
+
+    def disable_free_init(self):
+        """Disables the FreeInit mechanism if enabled."""
+        self._free_init_num_iters = None
+
+    @property
+    def free_init_enabled(self):
+        return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None
+
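These three members are the mixin's entire state machine, and since the mixin has no required attributes of its own it can be exercised in isolation; the import path below follows from the relative import in the first hunk:

from diffusers.pipelines.free_init_utils import FreeInitMixin

m = FreeInitMixin()
assert not m.free_init_enabled          # attribute not set yet
m.enable_free_init(num_iters=5, method="ideal")
assert m.free_init_enabled and m._free_init_num_iters == 5
m.disable_free_init()                   # resets num_iters to None
assert not m.free_init_enabled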
+    def _get_free_init_freq_filter(
+        self,
+        shape: Tuple[int, ...],
+        device: Union[str, torch.dtype],
+        filter_type: str,
+        order: float,
+        spatial_stop_frequency: float,
+        temporal_stop_frequency: float,
+    ) -> torch.Tensor:
+        r"""Returns the FreeInit filter based on filter type and other input conditions."""
+
+        time, height, width = shape[-3], shape[-2], shape[-1]
+        mask = torch.zeros(shape)
+
+        if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
+            return mask
+
+        if filter_type == "butterworth":
+
+            def retrieve_mask(x):
+                return 1 / (1 + (x / spatial_stop_frequency**2) ** order)
+        elif filter_type == "gaussian":
+
+            def retrieve_mask(x):
+                return math.exp(-1 / (2 * spatial_stop_frequency**2) * x)
+        elif filter_type == "ideal":
+
+            def retrieve_mask(x):
+                return 1 if x <= spatial_stop_frequency * 2 else 0
+        else:
+            raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")
+
+        for t in range(time):
+            for h in range(height):
+                for w in range(width):
+                    d_square = (
+                        ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / time - 1)) ** 2
+                        + (2 * h / height - 1) ** 2
+                        + (2 * w / width - 1) ** 2
+                    )
+                    mask[..., t, h, w] = retrieve_mask(d_square)
+
+        return mask.to(device)
+
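`d_square` is a normalized squared distance from the center of the (frames, height, width) volume, so after `fftshift` the mask keeps low frequencies (center, mask near 1) and suppresses high ones (corners, mask near 0). A quick sanity check of that claim:

import torch
from diffusers.pipelines.free_init_utils import FreeInitMixin

mask = FreeInitMixin()._get_free_init_freq_filter(
    shape=(1, 4, 8, 16, 16),  # (batch, channels, frames, height, width)
    device="cpu",
    filter_type="butterworth",
    order=4,
    spatial_stop_frequency=0.25,
    temporal_stop_frequency=0.25,
)
print(float(mask[0, 0, 4, 8, 8]))  # center of the volume: 1.0 (kept)
print(float(mask[0, 0, 0, 0, 0]))  # corner: ~2e-7 (suppressed)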
+    def _apply_freq_filter(self, x: torch.Tensor, noise: torch.Tensor, low_pass_filter: torch.Tensor) -> torch.Tensor:
+        r"""Noise reinitialization."""
+        # FFT
+        x_freq = fft.fftn(x, dim=(-3, -2, -1))
+        x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
+        noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
+        noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
+
+        # frequency mix
+        high_pass_filter = 1 - low_pass_filter
+        x_freq_low = x_freq * low_pass_filter
+        noise_freq_high = noise_freq * high_pass_filter
+        x_freq_mixed = x_freq_low + noise_freq_high  # mix in freq domain
+
+        # IFFT
+        x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
+        x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
+
+        return x_mixed
+
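This is the core FreeInit operation: keep the low-frequency band of `x` and take the high-frequency band from `noise`. Two limiting cases make the behavior easy to verify:

import torch
from diffusers.pipelines.free_init_utils import FreeInitMixin

m = FreeInitMixin()
x = torch.randn(1, 4, 8, 16, 16)
noise = torch.randn(1, 4, 8, 16, 16)

# all-ones filter: every frequency taken from x, so x comes back (up to FFT roundoff)
assert torch.allclose(m._apply_freq_filter(x, noise, low_pass_filter=torch.ones_like(x)), x, atol=1e-5)
# all-zeros filter: every frequency taken from noise
assert torch.allclose(m._apply_freq_filter(x, noise, low_pass_filter=torch.zeros_like(x)), noise, atol=1e-5)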
+    def _apply_free_init(
+        self,
+        latents: torch.Tensor,
+        free_init_iteration: int,
+        num_inference_steps: int,
+        device: torch.device,
+        dtype: torch.dtype,
+        generator: torch.Generator,
+    ):
+        if free_init_iteration == 0:
+            self._free_init_initial_noise = latents.detach().clone()
+            return latents, self.scheduler.timesteps
+
+        latent_shape = latents.shape
+
+        free_init_filter_shape = (1, *latent_shape[1:])
+        free_init_freq_filter = self._get_free_init_freq_filter(
+            shape=free_init_filter_shape,
+            device=device,
+            filter_type=self._free_init_method,
+            order=self._free_init_order,
+            spatial_stop_frequency=self._free_init_spatial_stop_frequency,
+            temporal_stop_frequency=self._free_init_temporal_stop_frequency,
+        )
+
+        current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
+        diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()
+
+        z_t = self.scheduler.add_noise(
+            original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
+        ).to(dtype=torch.float32)
+
+        z_rand = randn_tensor(
+            shape=latent_shape,
+            generator=generator,
+            device=device,
+            dtype=torch.float32,
+        )
+        latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
+        latents = latents.to(dtype)
+
+        # Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
+        if self._free_init_use_fast_sampling:
+            num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
+            self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+        return latents, self.scheduler.timesteps
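Under Coarse-to-Fine Sampling, pass 0 short-circuits through the early return and keeps the full schedule, while later passes grow linearly toward it. For example, with `num_inference_steps=20` and the default 3 iterations:

# pass 0 returns early and keeps the full schedule; passes >= 1 use the formula
num_inference_steps, num_iters = 20, 3
for it in range(1, num_iters):
    print(it, int(num_inference_steps / num_iters * (it + 1)))
# 1 13
# 2 20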
