@@ -648,25 +648,35 @@ def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, lorap
648648 logger .info (f"LoRA+ Text Encoder LR Ratio: { self .loraplus_text_encoder_lr_ratio or self .loraplus_lr_ratio } " )
649649
650650 def prepare_optimizer_params (self , text_encoder_lr , unet_lr ):
651- print (f"--- [DEBUG] prepare_optimizer_params: num unet loras = { len (self .unet_loras )} " )
652-
651+ # set requires_grad for all parameters
652+ for lora in self .text_encoder_loras + self .unet_loras :
653+ if hasattr (lora , 'lora_down' ):
654+ lora .lora_down .requires_grad_ (False )
655+ if hasattr (lora , 'lora_mid' ):
656+ lora .lora_mid .requires_grad_ (True )
657+ if hasattr (lora , 'lora_up' ):
658+ lora .lora_up .requires_grad_ (True )
659+
653660 def get_params (loras , lr ):
654661 if lr is None :
655662 return []
656663
657664 params = []
658665 for lora in loras :
659666 if hasattr (lora , 'lora_mid' ) and hasattr (lora , 'lora_up' ):
660- params .extend ([
661- {"params" : list (lora .lora_mid .parameters ()), "lr" : lr },
662- {"params" : list (lora .lora_up .parameters ()), "lr" : lr },
663- ])
667+ # Filter out parameters that don't require gradients
668+ mid_params = [p for p in lora .lora_mid .parameters () if p .requires_grad ]
669+ up_params = [p for p in lora .lora_up .parameters () if p .requires_grad ]
670+
671+ if mid_params :
672+ params .append ({"params" : mid_params , "lr" : lr })
673+ if up_params :
674+ params .append ({"params" : up_params , "lr" : lr })
664675 return params
665676
666677 text_encoder_params = get_params (self .text_encoder_loras , text_encoder_lr )
667678 unet_params = get_params (self .unet_loras , unet_lr )
668679
669- print (f"--- [DEBUG] prepare_optimizer_params: generated { len (unet_params )} unet param groups" )
670680 return text_encoder_params + unet_params
671681
672682 def enable_gradient_checkpointing (self ):
0 commit comments