
Commit 4da1275

Commit message: w

1 parent 2da1bbf commit 4da1275

File tree: 1 file changed (+17, −7 lines)


networks/meta_lora_flux.py

Lines changed: 17 additions & 7 deletions
@@ -648,25 +648,35 @@ def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, lorap
         logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
 
     def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
-        print(f"--- [DEBUG] prepare_optimizer_params: num unet loras = {len(self.unet_loras)}")
-
+        # set requires_grad for all parameters
+        for lora in self.text_encoder_loras + self.unet_loras:
+            if hasattr(lora, 'lora_down'):
+                lora.lora_down.requires_grad_(False)
+            if hasattr(lora, 'lora_mid'):
+                lora.lora_mid.requires_grad_(True)
+            if hasattr(lora, 'lora_up'):
+                lora.lora_up.requires_grad_(True)
+
         def get_params(loras, lr):
             if lr is None:
                 return []
 
             params = []
             for lora in loras:
                 if hasattr(lora, 'lora_mid') and hasattr(lora, 'lora_up'):
-                    params.extend([
-                        {"params": list(lora.lora_mid.parameters()), "lr": lr},
-                        {"params": list(lora.lora_up.parameters()), "lr": lr},
-                    ])
+                    # Filter out parameters that don't require gradients
+                    mid_params = [p for p in lora.lora_mid.parameters() if p.requires_grad]
+                    up_params = [p for p in lora.lora_up.parameters() if p.requires_grad]
+
+                    if mid_params:
+                        params.append({"params": mid_params, "lr": lr})
+                    if up_params:
+                        params.append({"params": up_params, "lr": lr})
             return params
 
         text_encoder_params = get_params(self.text_encoder_loras, text_encoder_lr)
         unet_params = get_params(self.unet_loras, unet_lr)
 
-        print(f"--- [DEBUG] prepare_optimizer_params: generated {len(unet_params)} unet param groups")
         return text_encoder_params + unet_params
 
     def enable_gradient_checkpointing(self):
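
Usage note: the parameter groups returned by prepare_optimizer_params are intended to be handed to an optimizer. Below is a minimal sketch of that hand-off, assuming a standard PyTorch optimizer (torch.optim.AdamW); the build_optimizer helper and the learning-rate values are hypothetical and not part of this commit.

    import torch

    def build_optimizer(network, text_encoder_lr=1e-5, unet_lr=1e-4):
        # With the requires_grad filtering in this commit, lora_down weights
        # (requires_grad=False) never reach the optimizer; only lora_mid and
        # lora_up parameters end up in the returned param groups.
        param_groups = network.prepare_optimizer_params(text_encoder_lr, unet_lr)
        if not param_groups:
            raise ValueError("no trainable LoRA parameters found")
        return torch.optim.AdamW(param_groups, weight_decay=0.01)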
