@@ -95,8 +95,13 @@ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
                 raise ValueError("diff_output_preservation requires a network to be set")
             if self.train_config.train_text_encoder:
                 raise ValueError("diff_output_preservation is not supported with train_text_encoder")
-
-            # always do a prior prediction when doing diff output preservation
+
+        if self.train_config.blank_prompt_preservation:
+            if self.network_config is None:
+                raise ValueError("blank_prompt_preservation requires a network to be set")
+
+        if self.train_config.blank_prompt_preservation or self.train_config.diff_output_preservation:
+            # always do a prior prediction when doing output preservation
             self.do_prior_prediction = True

         # store the loss target for a batch so we can use it in a loss
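Both preservation modes regularize a LoRA/adapter against the frozen base model, so each requires a network to be set and each forces a prior (base-model) prediction every step. A compact sketch of that gating, with the config field names taken from the diff and everything else illustrative:

# Sketch only: mirrors the validation/gating above, not the trainer itself.
def configure_preservation(train_config, network_config) -> bool:
    if train_config.diff_output_preservation or train_config.blank_prompt_preservation:
        if network_config is None:
            raise ValueError("output preservation requires a network to be set")
    if train_config.diff_output_preservation and train_config.train_text_encoder:
        raise ValueError("diff_output_preservation is not supported with train_text_encoder")
    # either mode needs the base model's prediction as a regularization target
    return train_config.diff_output_preservation or train_config.blank_prompt_preservation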
@@ -372,6 +377,13 @@ def hook_before_train_loop(self):
             self.sd.text_encoder_to("cpu")
             flush()

+        if self.train_config.blank_prompt_preservation and self.cached_blank_embeds is None:
+            # make sure we have this if not unloading
+            self.cached_blank_embeds = self.sd.encode_prompt("").to(
+                self.device_torch,
+                dtype=self.sd.torch_dtype
+            ).detach()
+
         if self.train_config.diffusion_feature_extractor_path is not None:
             vae = self.sd.vae
             # if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
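Caching the empty-prompt embedding up front means the text encoder never has to run (or stay on the GPU) during the training loop. A minimal sketch of the pattern, assuming an `encode_prompt`-style helper as used in the diff; other names are illustrative:

import torch

# Sketch: compute the "" embedding once, move it to the training device,
# and detach it so it is treated as a constant condition later.
def cache_blank_embeds(sd, device: torch.device):
    with torch.no_grad():
        blank = sd.encode_prompt("")  # helper name taken from the diff
    return blank.to(device, dtype=sd.torch_dtype).detach()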
@@ -634,33 +646,28 @@ def calculate_loss(
                 stepped_latents = torch.cat(stepped_chunks, dim=0)

                 stepped_latents = stepped_latents.to(self.sd.vae.device, dtype=self.sd.vae.dtype)
-                # resize to half the size of the latents
-                stepped_latents_half = torch.nn.functional.interpolate(
-                    stepped_latents,
-                    size=(stepped_latents.shape[2] // 2, stepped_latents.shape[3] // 2),
-                    mode='bilinear',
-                    align_corners=False
-                )
-                pred_features = self.dfe(stepped_latents.float())
-                pred_features_half = self.dfe(stepped_latents_half.float())
+                sl = stepped_latents
+                if len(sl.shape) == 5:
+                    # video B,C,T,H,W
+                    sl = sl.permute(0, 2, 1, 3, 4)  # B,T,C,H,W
+                    b, t, c, h, w = sl.shape
+                    sl = sl.reshape(b * t, c, h, w)
+                pred_features = self.dfe(sl.float())
                 with torch.no_grad():
-                    target_features = self.dfe(batch.latents.to(self.device_torch, dtype=torch.float32))
-                    batch_latents_half = torch.nn.functional.interpolate(
-                        batch.latents.to(self.device_torch, dtype=torch.float32),
-                        size=(batch.latents.shape[2] // 2, batch.latents.shape[3] // 2),
-                        mode='bilinear',
-                        align_corners=False
-                    )
-                    target_features_half = self.dfe(batch_latents_half)
+                    bl = batch.latents
+                    bl = bl.to(self.sd.vae.device)
+                    if len(bl.shape) == 5:
+                        # video B,C,T,H,W
+                        bl = bl.permute(0, 2, 1, 3, 4)  # B,T,C,H,W
+                        b, t, c, h, w = bl.shape
+                        bl = bl.reshape(b * t, c, h, w)
+                    target_features = self.dfe(bl.float())
                 # scale dfe so it is weaker at higher noise levels
                 dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch)

                 dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \
                     self.train_config.diffusion_feature_extractor_weight * dfe_scaler
-
-                dfe_loss_half = torch.nn.functional.mse_loss(pred_features_half, target_features_half, reduction="none") * \
-                    self.train_config.diffusion_feature_extractor_weight * dfe_scaler
-                additional_loss += dfe_loss.mean() + dfe_loss_half.mean()
+                additional_loss += dfe_loss.mean()
             elif self.dfe.version == 2:
                 # version 2
                 # do diffusion feature extraction on target
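The version-1 branch now handles video latents by folding the time axis into the batch axis so the 2D feature extractor sees one frame per row; a self-contained illustration of that shape transform (the tensor sizes are made up for the example):

import torch

def fold_time_into_batch(latents: torch.Tensor) -> torch.Tensor:
    # Video latents arrive as B,C,T,H,W; a 2D extractor wants N,C,H,W,
    # so move T next to B and flatten: B,C,T,H,W -> B,T,C,H,W -> (B*T),C,H,W.
    if latents.dim() == 5:
        b, c, t, h, w = latents.shape
        latents = latents.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
    return latents

frames = fold_time_into_batch(torch.randn(2, 16, 8, 32, 32))
assert frames.shape == (16, 16, 32, 32)  # 2 clips x 8 frames become 16 images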
@@ -1798,6 +1805,14 @@ def get_adapter_multiplier():
             if self.train_config.diff_output_preservation:
                 prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])

+            if self.train_config.blank_prompt_preservation:
+                blank_embeds = self.cached_blank_embeds.clone().detach().to(
+                    self.device_torch, dtype=dtype
+                )
+                prior_embeds_to_use = concat_prompt_embeds(
+                    [blank_embeds] * noisy_latents.shape[0]
+                )
+
             prior_pred = self.get_prior_prediction(
                 noisy_latents=noisy_latents,
                 conditional_embeds=prior_embeds_to_use,
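Under blank_prompt_preservation the prior prediction is conditioned on the cached empty-prompt embedding rather than the trigger-swapped prompt, so the frozen base model produces the target it would give for no prompt at all. A simplified sketch of the batch expansion; the real code wraps embeddings in a PromptEmbeds-style object combined with concat_prompt_embeds, while plain tensors and torch.cat are used here just to show the repetition:

import torch

def expand_blank_embeds(cached_blank_embeds: torch.Tensor, batch_size: int,
                        device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    # Reuse one cached empty-prompt embedding for every sample in the batch.
    blank = cached_blank_embeds.clone().detach().to(device, dtype=dtype)
    return torch.cat([blank] * batch_size, dim=0)  # one copy per batch item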
@@ -1973,7 +1988,8 @@ def get_adapter_multiplier():
             prior_to_calculate_loss = prior_pred
             # if we are doing diff_output_preservation and not doing inverted masked prior
             # then we need to send none here so it will not target the prior
-            if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
+            doing_preservation = self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation
+            if doing_preservation and not do_inverted_masked_prior:
                 prior_to_calculate_loss = None

             loss = self.calculate_loss(
@@ -1986,24 +2002,34 @@ def get_adapter_multiplier():
                 prior_pred=prior_to_calculate_loss,
             )

-            if self.train_config.diff_output_preservation:
+            if self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation:
                 # send the loss backwards otherwise checkpointing will fail
                 self.accelerator.backward(loss)
                 normal_loss = loss.detach()  # don't send backward again

-                dop_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
-                dop_pred = self.predict_noise(
+                with torch.no_grad():
+                    if self.train_config.diff_output_preservation:
+                        preservation_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
+                    elif self.train_config.blank_prompt_preservation:
+                        blank_embeds = self.cached_blank_embeds.clone().detach().to(
+                            self.device_torch, dtype=dtype
+                        )
+                        preservation_embeds = concat_prompt_embeds(
+                            [blank_embeds] * noisy_latents.shape[0]
+                        )
+                preservation_pred = self.predict_noise(
                     noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
                     timesteps=timesteps,
-                    conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype),
+                    conditional_embeds=preservation_embeds.to(self.device_torch, dtype=dtype),
                     unconditional_embeds=unconditional_embeds,
                     batch=batch,
                     **pred_kwargs
                 )
-                dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier
-                self.accelerator.backward(dop_loss)
-
-                loss = normal_loss + dop_loss
+                multiplier = self.train_config.diff_output_preservation_multiplier if self.train_config.diff_output_preservation else self.train_config.blank_prompt_preservation_multiplier
+                preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier
+                self.accelerator.backward(preservation_loss)
+
+                loss = normal_loss + preservation_loss
                 loss = loss.clone().detach()
                 # require grad again so the backward won't fail
                 loss.requires_grad_(True)
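The preservation term itself is a plain MSE between the adapted model's prediction for the preservation prompt and the prior prediction used as a fixed target, scaled by the mode's multiplier and backpropagated in its own pass, matching the diff's note that checkpointing needs a backward per prediction. A condensed sketch of that flow; function boundaries and any names not in the diff are illustrative:

import torch.nn.functional as F

# Sketch: pull the adapted prediction toward the prior prediction, run the
# two backward passes separately, then combine the losses for logging only.
def preservation_step(accelerator, loss, preservation_pred, prior_pred, multiplier):
    accelerator.backward(loss)                      # main loss first
    normal_loss = loss.detach()                     # don't send it backward again
    preservation_loss = F.mse_loss(preservation_pred, prior_pred) * multiplier
    accelerator.backward(preservation_loss)         # separate backward pass
    total = (normal_loss + preservation_loss).clone().detach()
    return total.requires_grad_(True)               # so a later backward won't fail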