34
34
)
35
35
from ...utils import USE_PEFT_BACKEND , logging , scale_lora_layers , unscale_lora_layers
36
36
from ...utils .torch_utils import randn_tensor
37
+ from ..free_init_utils import FreeInitMixin
37
38
from ..pipeline_utils import DiffusionPipeline
38
39
from .pipeline_output import AnimateDiffPipelineOutput
39
40
@@ -163,7 +164,9 @@ def retrieve_timesteps(
163
164
return timesteps , num_inference_steps
164
165
165
166
166
- class AnimateDiffVideoToVideoPipeline (DiffusionPipeline , TextualInversionLoaderMixin , IPAdapterMixin , LoraLoaderMixin ):
167
+ class AnimateDiffVideoToVideoPipeline (
168
+ DiffusionPipeline , TextualInversionLoaderMixin , IPAdapterMixin , LoraLoaderMixin , FreeInitMixin
169
+ ):
167
170
r"""
168
171
Pipeline for video-to-video generation.
169
172
@@ -193,7 +196,7 @@ class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderM
193
196
"""
194
197
195
198
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
196
- _optional_components = ["feature_extractor" , "image_encoder" ]
199
+ _optional_components = ["feature_extractor" , "image_encoder" , "motion_adapter" ]
197
200
_callback_tensor_inputs = ["latents" , "prompt_embeds" , "negative_prompt_embeds" ]
198
201
199
202
def __init__ (
@@ -215,7 +218,8 @@ def __init__(
215
218
image_encoder : CLIPVisionModelWithProjection = None ,
216
219
):
217
220
super ().__init__ ()
218
- unet = UNetMotionModel .from_unet2d (unet , motion_adapter )
221
+ if isinstance (unet , UNet2DConditionModel ):
222
+ unet = UNetMotionModel .from_unet2d (unet , motion_adapter )
219
223
220
224
self .register_modules (
221
225
vae = vae ,
@@ -584,12 +588,12 @@ def check_inputs(
584
588
if video is not None and latents is not None :
585
589
raise ValueError ("Only one of `video` or `latents` should be provided" )
586
590
587
- def get_timesteps (self , num_inference_steps , strength , device ):
591
+ def get_timesteps (self , num_inference_steps , timesteps , strength , device ):
588
592
# get the original timestep using init_timestep
589
593
init_timestep = min (int (num_inference_steps * strength ), num_inference_steps )
590
594
591
595
t_start = max (num_inference_steps - init_timestep , 0 )
592
- timesteps = self . scheduler . timesteps [t_start * self .scheduler .order :]
596
+ timesteps = timesteps [t_start * self .scheduler .order :]
593
597
594
598
return timesteps , num_inference_steps - t_start
595
599
@@ -876,9 +880,8 @@ def __call__(
876
880
877
881
# 4. Prepare timesteps
878
882
timesteps , num_inference_steps = retrieve_timesteps (self .scheduler , num_inference_steps , device , timesteps )
879
- timesteps , num_inference_steps = self .get_timesteps (num_inference_steps , strength , device )
883
+ timesteps , num_inference_steps = self .get_timesteps (num_inference_steps , timesteps , strength , device )
880
884
latent_timestep = timesteps [:1 ].repeat (batch_size * num_videos_per_prompt )
881
- self ._num_timesteps = len (timesteps )
882
885
883
886
# 5. Prepare latent variables
884
887
num_channels_latents = self .unet .config .in_channels
@@ -901,42 +904,55 @@ def __call__(
901
904
# 7. Add image embeds for IP-Adapter
902
905
added_cond_kwargs = {"image_embeds" : image_embeds } if ip_adapter_image is not None else None
903
906
904
- # 8. Denoising loop
905
- num_warmup_steps = len (timesteps ) - num_inference_steps * self .scheduler .order
906
- with self .progress_bar (total = num_inference_steps ) as progress_bar :
907
- for i , t in enumerate (timesteps ):
908
- # expand the latents if we are doing classifier free guidance
909
- latent_model_input = torch .cat ([latents ] * 2 ) if self .do_classifier_free_guidance else latents
910
- latent_model_input = self .scheduler .scale_model_input (latent_model_input , t )
911
-
912
- # predict the noise residual
913
- noise_pred = self .unet (
914
- latent_model_input ,
915
- t ,
916
- encoder_hidden_states = prompt_embeds ,
917
- cross_attention_kwargs = self .cross_attention_kwargs ,
918
- added_cond_kwargs = added_cond_kwargs ,
919
- ).sample
920
-
921
- # perform guidance
922
- if self .do_classifier_free_guidance :
923
- noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
924
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
925
-
926
- # compute the previous noisy sample x_t -> x_t-1
927
- latents = self .scheduler .step (noise_pred , t , latents , ** extra_step_kwargs ).prev_sample
928
-
929
- if callback_on_step_end is not None :
930
- callback_kwargs = {}
931
- for k in callback_on_step_end_tensor_inputs :
932
- callback_kwargs [k ] = locals ()[k ]
933
- callback_outputs = callback_on_step_end (self , i , t , callback_kwargs )
934
-
935
- latents = callback_outputs .pop ("latents" , latents )
936
- prompt_embeds = callback_outputs .pop ("prompt_embeds" , prompt_embeds )
937
- negative_prompt_embeds = callback_outputs .pop ("negative_prompt_embeds" , negative_prompt_embeds )
938
-
939
- progress_bar .update ()
907
+ num_free_init_iters = self ._free_init_num_iters if self .free_init_enabled else 1
908
+ for free_init_iter in range (num_free_init_iters ):
909
+ if self .free_init_enabled :
910
+ latents , timesteps = self ._apply_free_init (
911
+ latents , free_init_iter , num_inference_steps , device , latents .dtype , generator
912
+ )
913
+ num_inference_steps = len (timesteps )
914
+ # make sure to readjust timesteps based on strength
915
+ timesteps , num_inference_steps = self .get_timesteps (num_inference_steps , timesteps , strength , device )
916
+
917
+ self ._num_timesteps = len (timesteps )
918
+ num_warmup_steps = len (timesteps ) - num_inference_steps * self .scheduler .order
919
+ # 8. Denoising loop
920
+ with self .progress_bar (total = num_inference_steps ) as progress_bar :
921
+ for i , t in enumerate (timesteps ):
922
+ # expand the latents if we are doing classifier free guidance
923
+ latent_model_input = torch .cat ([latents ] * 2 ) if self .do_classifier_free_guidance else latents
924
+ latent_model_input = self .scheduler .scale_model_input (latent_model_input , t )
925
+
926
+ # predict the noise residual
927
+ noise_pred = self .unet (
928
+ latent_model_input ,
929
+ t ,
930
+ encoder_hidden_states = prompt_embeds ,
931
+ cross_attention_kwargs = self .cross_attention_kwargs ,
932
+ added_cond_kwargs = added_cond_kwargs ,
933
+ ).sample
934
+
935
+ # perform guidance
936
+ if self .do_classifier_free_guidance :
937
+ noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
938
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
939
+
940
+ # compute the previous noisy sample x_t -> x_t-1
941
+ latents = self .scheduler .step (noise_pred , t , latents , ** extra_step_kwargs ).prev_sample
942
+
943
+ if callback_on_step_end is not None :
944
+ callback_kwargs = {}
945
+ for k in callback_on_step_end_tensor_inputs :
946
+ callback_kwargs [k ] = locals ()[k ]
947
+ callback_outputs = callback_on_step_end (self , i , t , callback_kwargs )
948
+
949
+ latents = callback_outputs .pop ("latents" , latents )
950
+ prompt_embeds = callback_outputs .pop ("prompt_embeds" , prompt_embeds )
951
+ negative_prompt_embeds = callback_outputs .pop ("negative_prompt_embeds" , negative_prompt_embeds )
952
+
953
+ # call the callback, if provided
954
+ if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
955
+ progress_bar .update ()
940
956
941
957
if output_type == "latent" :
942
958
return AnimateDiffPipelineOutput (frames = latents )
0 commit comments